#!/usr/bin/env python2
# -*- coding: utf-8 -*-

# Copyright (c) 2019 Battelle Energy Alliance, LLC.  All rights reserved.

from __future__ import print_function

import sys
import os
import re
import argparse
import struct
import ipaddress
import itertools
import pprint
import uuid
from collections import defaultdict

UNSPECIFIED_TAG = '<~<~<none>~>~>'
HOST_LIST_IDX = 0
SEGMENT_LIST_IDX = 1

###################################################################################################
# print to stderr
def eprint(*args, **kwargs):
  """Print the given objects to standard error.

  Keyword options (sep, end, ...) are forwarded to print(); passing 'file'
  is an error because the destination is always sys.stderr.
  """
  errStream = sys.stderr
  print(*args, file=errStream, **kwargs)

###################################################################################################
# main
def main():
  """Generate a Logstash filter that names hosts and network segments seen in Zeek events.

  Command line:
    -s/--segment FILE...  segment mapping file(s): "IP(s)|segment name|required tag" per line
    -h/--host FILE...     host mapping file(s): "address(es)|host name|required tag" per line
    -o/--output FILE      output file ('-', the default, writes to stdout)

  Blank lines and lines starting with '#' are ignored, and the required tag is optional.
  Malformed entries are reported on stderr and skipped. If the inputs contain no
  non-blank, non-comment lines, nothing is written and the output file is not created.
  Invalid arguments print the usage text and exit with status 2.
  """
  parser = argparse.ArgumentParser(description='Logstash IP address to Segment Filter Creator', add_help=False, usage='ip-to-segment-logstash.py <arguments>')
  parser.add_argument('-s', '--segment', dest='segmentInput', metavar='<STR>', type=str, nargs='*', default='', help='Input segment mapping file(s)')
  parser.add_argument('-h', '--host', dest='hostInput', metavar='<STR>', type=str, nargs='*', default='', help='Input host mapping file(s)')
  parser.add_argument('-o', '--output', dest='output', metavar='<STR>', type=str, default='-', help='Output file')
  try:
    # Route parse errors through exit(), which raises SystemExit. That exception is
    # caught here so the full usage text can be shown. There is no built-in help
    # option because -h is taken by --host.
    parser.error = parser.exit
    args = parser.parse_args()
  except SystemExit:
    parser.print_help()
    sys.exit(2)

  segmentLines = _load_mapping_lines(args.segmentInput)
  hostLines = _load_mapping_lines(args.hostInput)
  if not (segmentLines or hostLines):
    return

  filterId = 0         # running counter that keeps generated Logstash plugin ids unique
  addedFields = set()  # event fields the generated filters may add (deduplicated at the end)

  outFile = open(args.output, 'w+') if (args.output and args.output != '-') else sys.stdout
  try:
    try:
      print('filter {', file=outFile)
      print("", file=outFile)
      print("  # this file was automatically generated by {}".format(os.path.basename(__file__)), file=outFile)
      print("", file=outFile)

      # tagListMap[required tag][HOST_LIST_IDX|SEGMENT_LIST_IDX][host/segment name] = [addresses]
      # eg., tagListMap[tag][SEGMENT_LIST_IDX]['corp'] = ['172.16.0.0/12', '192.168.0.0/24', '10.0.0.41']
      tagListMap = defaultdict(lambda: [defaultdict(list), defaultdict(list)])
      _add_segment_mappings(segmentLines, tagListMap)
      _add_host_mappings(hostLines, tagListMap)

      # emit filters grouped by required tag, then by host/segment name
      for tag, nameMaps in tagListMap.items():
        print("", file=outFile)

        # if a tag name is specified, wrap this group's filters in a check for the tag's presence
        if tag != UNSPECIFIED_TAG:
          print('  if ("{}" in [tags]) {{'.format(tag), file=outFile)
        try:

          # host names: one conditional per direction (orig/resp) for IPs, and another for MACs
          for hostName, addrList in nameMaps[HOST_LIST_IDX].items():

            # ip addresses mapped to hostname
            ipList = [a for a in addrList if not a.startswith('_')]
            if ipList:
              for source in ['orig', 'resp']:
                filterId += 1
                fieldName = "{}_h".format(source)
                newFieldName = "{}_hostname".format(source)
                print("", file=outFile)
                print('    if ([zeek][{}]) and ({}) {{ '.format(fieldName, ' or '.join(['([zeek][{}] == "{}")'.format(fieldName, ip) for ip in ipList])), file=outFile)
                print('      mutate {{ id => "mutate_add_autogen_{}_ip_hostname_{}"'.format(source, filterId), file=outFile)
                print('        add_field => {{ "[zeek][{}]" => "{}" }}'.format(newFieldName, hostName), file=outFile)
                print("      }", file=outFile)
                print("    }", file=outFile)
                addedFields.add("[zeek][{}]".format(newFieldName))

            # mac addresses mapped to hostname (stored with a leading '_' marker, stripped here)
            macList = [a for a in addrList if a.startswith('_')]
            if macList:
              for source in ['orig', 'resp']:
                filterId += 1
                fieldName = "{}_l2_addr".format(source)
                newFieldName = "{}_hostname".format(source)
                print("", file=outFile)
                print('    if ([zeek][{}]) and ({}) {{ '.format(fieldName, ' or '.join(['([zeek][{}] == "{}")'.format(fieldName, mac[1:]) for mac in macList])), file=outFile)
                print('      mutate {{ id => "mutate_add_autogen_{}_mac_hostname_{}"'.format(source, filterId), file=outFile)
                print('        add_field => {{ "[zeek][{}]" => "{}" }}'.format(newFieldName, hostName), file=outFile)
                print("      }", file=outFile)
                print("    }", file=outFile)
                addedFields.add("[zeek][{}]".format(newFieldName))

          # segments: one cidr filter per direction (orig/resp)
          for segmentName, ipList in nameMaps[SEGMENT_LIST_IDX].items():
            for source in ['orig', 'resp']:
              filterId += 1
              fieldName = "{}_h".format(source)
              newFieldName = "{}_segment".format(source)
              print("", file=outFile)
              print("    if ([zeek][{}]) {{ cidr {{".format(fieldName), file=outFile)
              print('      id => "cidr_autogen_{}_segment_{}"'.format(source, filterId), file=outFile)
              print('      address => [ "%{{[zeek][{}]}}" ]'.format(fieldName), file=outFile)
              print('      network => [ {} ]'.format(', '.join('"{}"'.format(ip) for ip in ipList)), file=outFile)
              print('      add_tag => [ "{}" ]'.format(segmentName), file=outFile)
              print('      add_field => {{ "[zeek][{}]" => "{}" }}'.format(newFieldName, segmentName), file=outFile)
              print("    } }", file=outFile)
              addedFields.add("[zeek][{}]".format(newFieldName))

        finally:
          # if a tag name is specified, close the IF statement verifying the tag's presence
          if tag != UNSPECIFIED_TAG:
            print("", file=outFile)
            print('  }} # end (if "{}" in [tags])'.format(tag), file=outFile)

    finally:
      # a field may be added by several matching filters, so collapse array values to unique entries
      if addedFields:
        print("", file=outFile)
        print('  # deduplicate any added fields', file=outFile)
        for direction, kind in itertools.product(['orig', 'resp'], ['hostname', 'segment']):
          newFieldName = "[zeek][{}_{}]".format(direction, kind)
          if newFieldName in addedFields:
            # plugin id: non-alphanumeric runs become '_', then runs of repeated characters are collapsed
            rubyId = ''.join(c for c, _ in itertools.groupby(re.sub('[^0-9a-zA-Z]+', '_', newFieldName)))
            print("", file=outFile)
            print('  if ({}) {{ '.format(newFieldName), file=outFile)
            print('    ruby {{ id => "ruby{}deduplicate"'.format(rubyId), file=outFile)
            print('      code => "', file=outFile)
            print("        fieldVals = event.get('{}')".format(newFieldName), file=outFile)
            print("        if fieldVals.kind_of?(Array) then event.set('{}', fieldVals.uniq) end".format(newFieldName), file=outFile)
            print('      "', file=outFile)
            print('  } }', file=outFile)

      # close out filter with ending }
      print("", file=outFile)
      print('} # end Filter', file=outFile)

  finally:
    # close the output file even if generation failed part-way through
    if outFile is not sys.stdout:
      outFile.close()

###################################################################################################
# convert a string to the text type ipaddress expects
def _to_text(value):
  """Return value as a text string (unicode on Python 2, unchanged on Python 3).

  The Python 2 ipaddress backport expects unicode input. Python 3 has no
  'unicode' builtin, and its str is already text.
  """
  try:
    return unicode(value)  # Python 2
  except NameError:
    return value  # Python 3

###################################################################################################
# read mapping files into a list of lines, dropping blank lines and comments
def _load_mapping_lines(paths):
  """Return the stripped, non-blank, non-'#' lines of every existing regular file in paths.

  Paths that are not existing regular files are skipped without a warning.
  """
  lines = []
  for inFile in paths:
    if os.path.isfile(inFile):
      with open(inFile) as f:
        lines.extend(line.strip() for line in f)
  return [x for x in lines if x and not x.startswith('#')]

###################################################################################################
# parse "IP(s)|segment name|required tag" lines into tagListMap
def _add_segment_mappings(segmentLines, tagListMap):
  """Add segment mappings to tagListMap[tag][SEGMENT_LIST_IDX][segment name].

  Each line has the form 'IP(s)|segment name|required tag', where:
    IP(s): comma-separated CIDR networks or single addresses, eg., 10.0.0.0/8, 172.16.10.41
    segment name: assigned when an event IP address matches
    required tag (optional): only match events carrying this tag
  Invalid addresses and malformed lines are reported on stderr and skipped.
  """
  for line in segmentLines:
    values = [x.strip() for x in line.split('|')]
    if len(values) < 2:
      eprint('"{}" is not formatted correctly, ignoring'.format(line))
      continue
    networkList = []
    for ip in ''.join(values[0].split()).split(','):
      try:
        if '/' in ip:
          networkList.append(str(ipaddress.ip_network(_to_text(ip))).lower())
        else:
          networkList.append(str(ipaddress.ip_address(_to_text(ip))).lower())
      except ValueError:
        eprint('"{}" is not a valid IP address, ignoring'.format(ip))
    segmentName = values[1]
    tagReq = values[2] if ((len(values) >= 3) and (len(values[2]) > 0)) else UNSPECIFIED_TAG
    if networkList and segmentName:
      tagListMap[tagReq][SEGMENT_LIST_IDX][segmentName].extend(networkList)
    else:
      eprint('"{}" is not formatted correctly, ignoring'.format(line))

# six pairs of hex digits, optionally separated by ':' or '-', and nothing else (eg., 02:42:45:dc:a2:96)
_MAC_ADDR_REGEX = re.compile(r'^[a-fA-F0-9]{2}([:\-]?[a-fA-F0-9]{2}){5}$')

###################################################################################################
# parse "address(es)|host name|required tag" lines into tagListMap
def _add_host_mappings(hostLines, tagListMap):
  """Add host mappings to tagListMap[tag][HOST_LIST_IDX][host name].

  Each line has the form 'address(es)|host name|required tag', where:
    address(es): comma-separated IPv4, IPv6, or MAC addresses,
                 eg., 172.16.10.41, 02:42:45:dc:a2:96, 2001:db8::8a2e:370:7334
    host name: assigned when an event address matches
    required tag (optional): only match events carrying this tag
  MAC addresses are stored lowercased with '-' replaced by ':' and a leading '_'
  that marks them as MACs for the emitter. Invalid entries are reported on stderr.
  """
  for line in hostLines:
    values = [x.strip() for x in line.split('|')]
    if len(values) < 2:
      eprint('"{}" is not formatted correctly, ignoring'.format(line))
      continue
    addressList = []
    for addr in ''.join(values[0].split()).split(','):
      try:
        # see if it's an IP address
        addressList.append(str(ipaddress.ip_address(_to_text(addr))).lower())
      except ValueError:
        # otherwise see if it's a MAC address
        if _MAC_ADDR_REGEX.match(addr):
          addressList.append("_{}".format(addr.replace('-', ':').lower()))
        else:
          eprint('"{}" is not a valid IP or MAC address, ignoring'.format(addr))
    hostName = values[1]
    tagReq = values[2] if ((len(values) >= 3) and (len(values[2]) > 0)) else UNSPECIFIED_TAG
    if addressList and hostName:
      tagListMap[tagReq][HOST_LIST_IDX][hostName].extend(addressList)
    else:
      eprint('"{}" is not formatted correctly, ignoring'.format(line))

if __name__ == '__main__':
  main()